#include <torch/torch.h>

/// Forward pass of the sigmoid-based GELU approximation.
///
/// Evaluates `input * sigmoid(1.702 * input)` element-wise.
///
/// @param input Tensor to activate.
/// @return Tensor holding `input` scaled element-wise by its sigmoid gate.
torch::Tensor gelu_fwd(torch::Tensor input) {
    // Slope applied to the input before the sigmoid.
    constexpr double kSlope = 1.702;

    const auto gate = torch::sigmoid(kSlope * input);
    return input * gate;
}

/// Backward pass matching the forward expression `x * sigmoid(1.702 * x)`.
///
/// With `s = sigmoid(1.702 * x)`, the derivative of `x * s` with respect to
/// `x` is `s + 1.702 * x * s * (1 - s)`; the result is that local derivative
/// multiplied element-wise by the incoming gradient.
///
/// @param grad_out Gradient flowing in from the output of the forward pass.
/// @param input    Tensor that was given to the forward pass.
/// @return Product of `grad_out` and the local derivative evaluated at `input`.
torch::Tensor gelu_bwd(torch::Tensor grad_out, torch::Tensor input){
    // Slope applied to the input before the sigmoid.
    constexpr double kSlope = 1.702;

    const auto sig = torch::sigmoid(kSlope * input);
    const auto one_minus_sig = 1.0 - sig;

    // Product rule: the sigmoid itself plus the input times the sigmoid's slope.
    const auto local_grad = sig + kSlope * input * sig * one_minus_sig;
    return grad_out * local_grad;
}


// Python bindings: registers the two functions above on the extension module
// under the names "forward" and "backward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    // forward(input)
    ext.def("forward", &gelu_fwd, "GELU forward");
    // backward(grad_out, input)
    ext.def("backward", &gelu_bwd, "GELU backward");
}